// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdlib>
#include <iostream>
#include <vector>

#include "gtest/gtest.h"
#include "test/common/AcceleratorHarness.h"
#include "test/common/GoldModel.h"

TEST(Simple, Transpose) {
  std::cout << "Transpose Test" << std::endl;
  std::cout << "--------------" << std::endl;

  INPUT_DATATYPE *mainMemory = new INPUT_DATATYPE[4 * 1024 * 1024];

  const Params params = {
      4,      // M0
      1,      // P1
      1,      // N1
      1,      // M1
      1,      // P2
      0,      // INPUT_OFFSET
      1024,   // WEIGHT_OFFSET
      2048,   // OUTPUT_OFFSET
      false,  // SOFTMAX
      1,      // SCALE
      true,   // TRANSPOSE
      0,      // VECTOR_OFFSET
      false,  // VEC_OP
      false,  // VEC_SUB
      false,  // VEC_SQUARE
      false,  // VEC_REDUCE
      true,   // CONST_SCALE
      0,      // VEC_SCALE_OFFSET
      0       // VEC_SUB_OFFSET
  };

  // Create matrix A
  INPUT_DATATYPE *matrixA =
      new INPUT_DATATYPE[params.M0 * params.M1 * params.N1 * DIMENSION];
  for (int i = 0; i < params.M0 * params.M1; i++) {
    for (int j = 0; j < params.N1 * DIMENSION; j++) {
      int val = rand() % 128;

      mainMemory[params.INPUT_OFFSET + i * (params.N1 * DIMENSION) + j] =
      val; matrixA[i * (params.N1 * DIMENSION) + j] = val;
    }
  }

  // Create matrix B
  INPUT_DATATYPE *matrixB =
      new INPUT_DATATYPE[params.N1 * DIMENSION * params.P1 * params.P2 *
                         DIMENSION];
  for (int i = 0; i < params.N1 * DIMENSION; i++) {
    for (int j = 0; j < params.P1 * params.P2 * DIMENSION; j++) {
      int val = rand() % 128;

      mainMemory[params.WEIGHT_OFFSET +
                 i * (params.P1 * params.P2 * DIMENSION) + j] = val;
      matrixB[i * (params.P1 * params.P2 * DIMENSION) + j] = val;
    }
  }

  // Create matrix C
  OUTPUT_DATATYPE *matrixC =
      new OUTPUT_DATATYPE[params.M0 * params.M1 * params.P1 * params.P2 *
                          DIMENSION];

  run_op(params, mainMemory);
  run_gold_op(params, matrixA, matrixB, matrixC);
  check_outputs(params, mainMemory, matrixC);

  delete[] matrixA;
  delete[] matrixB;
  delete[] matrixC;
  delete[] mainMemory;
}

// Program entry point. The name suggests this is linked against SystemC,
// whose runtime provides main() and calls sc_main (confirm with the build).
// It hands the command line to GoogleTest, which may consume gtest-specific
// flags, runs every registered test, and returns gtest's exit status
// (0 when all tests pass).
int sc_main(int argc, char *argv[]) {
  ::testing::InitGoogleTest(&argc, argv);
  const int testStatus = RUN_ALL_TESTS();
  return testStatus;
}
